# -*- coding: utf-8 -*-
"""
# @FileName:     get_dataloader.py
# @AuthorName:   Sanqi Lu (Lingwei Dang)
# @Institution:  SCUT, Guangzhou, China
# @EmailAddress: lenvondang@163.com
# @CreateTime:   2024/12/22 12:31
"""
import torchvision
from torchvision.transforms import ToTensor, Compose, Lambda
from torch.utils.data.dataloader import DataLoader

def get_dataloader(
    batch_size: int,
    *,
    root: str = r'E:\datas\MNIST',
    train: bool = True,
    download: bool = False,
    shuffle: bool = True,
    num_workers: int = 0,
) -> DataLoader:
    """Build a DataLoader over MNIST with pixel values rescaled to [-1, 1].

    Args:
        batch_size: Number of samples per batch.
        root: Directory that holds (or, with ``download=True``, receives) the
            MNIST files. The default is the original hard-coded location;
            pass a different path on other machines instead of editing code.
        train: Use the training split if True, otherwise the test split.
        download: Download the dataset into ``root`` if it is not present.
        shuffle: Reshuffle the samples at every epoch.
        num_workers: Number of loader worker subprocesses (0 = load in the
            main process, which was the previous fixed behavior).

    Returns:
        A ``DataLoader`` over ``torchvision.datasets.MNIST`` items
        (image, label), with images transformed as described above.
    """
    # ToTensor maps pixel values into [0, 1]; Normalize(0.5, 0.5) then
    # computes (x - 0.5) / 0.5, i.e. the same mapping to [-1, 1] as the
    # former Lambda(lambda x: (x - 0.5) * 2). Unlike a lambda, Normalize is
    # picklable, which DataLoader workers need under the 'spawn' start method
    # (the default on Windows and macOS) when num_workers > 0.
    transform = Compose([
        ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    dataset = torchvision.datasets.MNIST(
        root=root,
        train=train,
        download=download,
        transform=transform,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )